import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def plot_bandit_diffusion(all_results, projector):
    """
    Produce every diagnostic figure for a batch of bandit-diffusion runs.

    Applies the seaborn "whitegrid" style, then writes, in order: the
    condition-vs-obtained scatter, the manifold-distance curve, and the
    baseline figures produced by ``plot_baseline``.

    Args:
        all_results (List[Dict]): Result dictionaries, one per run.
        projector (Tensor): Projector forwarded unchanged to ``plot_baseline``.
    """
    sns.set(style="whitegrid")

    plot_condition_vs_obtained(all_results)

    distances_per_run = calculate_all_distances(all_results)
    plot_manifold_distance(distances_per_run)

    plot_baseline(all_results, projector)


def plot_condition_vs_obtained(all_results):
    """
    Scatter obtained rewards against the requested condition for every run.

    Each run is drawn as its own scatter series (alpha 0.5). A red dashed
    identity line ``y = x`` spans the overall condition range of all runs.
    The figure is saved to ``condition_vs_obtained_all_runs.png`` and closed.

    Args:
        all_results (List[Dict]): List of result dictionaries from multiple runs.
            Each run must provide ``"condition"`` and ``"rewards_iterate"`` as
            sequences of arrays that ``np.concatenate`` can join.

    Raises:
        ValueError: If ``all_results`` is empty, because there is then no
            condition range over which to draw the identity line.
    """
    if not all_results:
        # Fail before opening a figure so no dangling pyplot figure is left open.
        raise ValueError("all_results must contain at least one run")

    # Flatten each run's conditions once. They are reused both for the scatter
    # and for the identity-line range, instead of re-concatenating per use.
    conditions = [np.concatenate(run["condition"]).flatten() for run in all_results]

    plt.figure(figsize=(10, 6))
    for run, condition in zip(all_results, conditions):
        rewards_iterate = np.concatenate(run["rewards_iterate"]).flatten()
        plt.scatter(
            condition,
            rewards_iterate,
            alpha=0.5,
        )

    # The global min/max equals the min/max of the per-run extrema.
    min_cond = min(np.min(condition) for condition in conditions)
    max_cond = max(np.max(condition) for condition in conditions)
    plt.plot(
        [min_cond, max_cond],
        [min_cond, max_cond],
        "r--",
    )
    plt.xlabel("Condition")
    plt.ylabel("Obtained")
    plt.title("Condition vs Obtained (All Runs)")
    plt.savefig("condition_vs_obtained_all_runs.png")
    plt.close()


def calculate_all_distances(all_results):
    """
    Collect each run's distance trace into a single array.

    Every run's ``"distances"`` entry (a sequence of arrays) is concatenated
    and flattened into one 1-D trace. The traces are then stacked with
    ``np.array``, so runs lie along axis 0.

    Args:
        all_results (List[Dict]): List of result dictionaries from multiple runs.

    Returns:
        np.ndarray: One row of distances per run.
    """
    traces = []
    for result in all_results:
        traces.append(np.concatenate(result["distances"]).flatten())
    return np.array(traces)


def plot_manifold_distance(all_distances):
    """
    Plot the per-iteration mean distance to the manifold with a ±1 std band.

    Statistics are taken across runs (axis 0). The figure is saved to
    ``manifold_distance.png`` and closed.

    Args:
        all_distances (np.ndarray): Distances with one row per run, e.g. the
            output of ``calculate_all_distances``.
    """
    plt.figure(figsize=(10, 6))
    center = np.mean(all_distances, axis=0)
    spread = np.std(all_distances, axis=0)
    steps = np.arange(len(center))
    upper = center + spread
    lower = center - spread

    plt.plot(steps, center, label="Mean Distance")
    plt.fill_between(steps, upper, lower, alpha=0.3, label=r"$\pm$ std")

    plt.xlabel("Iteration")
    plt.ylabel("Distance to Manifold")
    plt.legend()
    plt.savefig("manifold_distance.png")
    plt.close()


def plot_baseline(all_results, projector, mode: str = ""):
    """
    Write the baseline diagnostic figures for a set of runs.

    Applies the seaborn "whitegrid" style and then produces, in order:
    cumulative regret, reward convergence, and theta convergence in the L2
    and L∞ norms.

    Args:
        all_results (List[Dict]): List of result dictionaries from multiple runs.
        projector (Tensor): Projector applied to theta values before the
            theta errors are measured.
        mode (str): Suffix passed to each plotting helper for its output
            filename.
    """
    sns.set(style="whitegrid")

    # Regret
    regrets = calculate_standardized_regrets(all_results)
    plot_cumulative_regret(regrets, mode)

    # Rewards versus the mean best achievable reward
    rewards, reward_ceiling = calculate_standardized_rewards(all_results)
    plot_rewards_convergence(rewards, reward_ceiling, mode)

    # Theta estimation error: L2 pair first, then L-infinity pair
    l2_sample, l2_posterior, inf_sample, inf_posterior = calculate_theta_differences(
        all_results, projector
    )
    plot_theta_convergence_L2(l2_sample, l2_posterior, mode)
    plot_theta_convergence_Linf(inf_sample, inf_posterior, mode)


def calculate_standardized_regrets(all_results):
    """
    Compute the cumulative regret trajectory of every run.

    A run's per-step regret is ``max_obtainable`` minus its flattened
    ``rewards_gt``. The running sum of that regret forms the run's row.

    Args:
        all_results (List[Dict]): List of result dictionaries from multiple runs.

    Returns:
        np.ndarray: Cumulative regrets, one row per run.
    """
    return np.array(
        [
            np.cumsum(
                result["max_obtainable"]
                - np.concatenate(result["rewards_gt"]).flatten()
            )
            for result in all_results
        ]
    )


def plot_cumulative_regret(standardized_regrets, mode=""):
    """
    Plot the across-run mean cumulative regret with a ±1 std band.

    The figure is written to ``cumulative_regret_{mode}.png`` and closed.

    Args:
        standardized_regrets (np.ndarray): Cumulative regrets with one row per
            run; mean and std are taken over axis 0.
        mode (str): Suffix inserted into the output filename.
    """
    plt.figure(figsize=(10, 6))
    avg = np.mean(standardized_regrets, axis=0)
    spread = np.std(standardized_regrets, axis=0)
    steps = np.arange(len(avg))

    plt.plot(steps, avg, label="Mean Regret")
    plt.fill_between(steps, avg + spread, avg - spread, alpha=0.3, label=r"$\pm$ std")

    plt.xlabel("Iteration")
    plt.ylabel("Cumulative Regret")
    plt.legend()
    out_path = f"cumulative_regret_{mode}.png"
    plt.savefig(out_path)
    plt.close()


def calculate_standardized_rewards(all_results):
    """
    Stack each run's ground-truth rewards and average the best achievable reward.

    Args:
        all_results (List[Dict]): List of result dictionaries from multiple runs.

    Returns:
        Tuple[np.ndarray, float]: The flattened ``rewards_gt`` of every run
        (one row per run), and the mean of the runs' ``max_obtainable`` values.
    """
    per_run_rewards = np.array(
        [np.concatenate(result["rewards_gt"]).flatten() for result in all_results]
    )
    ceilings = [result["max_obtainable"] for result in all_results]
    return per_run_rewards, np.mean(ceilings)


def plot_rewards_convergence(standardized_rewards, mean_max, mode=""):
    """
    Plot the mean reward per iteration against a horizontal reference level.

    Draws the across-run mean with a ±1 std band, plus a red dashed line at
    ``mean_max`` spanning the first to the last iteration. The figure is
    saved as ``reward_convergence_{mode}.png`` and closed.

    Args:
        standardized_rewards (np.ndarray): Rewards with one row per run; mean
            and std are taken over axis 0.
        mean_max (float): Height of the dashed reference line.
        mode (str): Suffix inserted into the output filename.
    """
    plt.figure(figsize=(10, 6))
    avg = np.mean(standardized_rewards, axis=0)
    spread = np.std(standardized_rewards, axis=0)
    steps = np.arange(len(avg))
    last_step = len(avg) - 1

    plt.plot(steps, avg, label="Mean Rewards")
    plt.fill_between(steps, avg + spread, avg - spread, alpha=0.2, label=r"$\pm$ std")

    # Reference level: expected best reward averaged over runs.
    plt.plot(
        [0, last_step],
        [mean_max] * 2,
        "r--",
        label=r"$E_{\theta \sim q}[\theta^T x^*]$",
    )

    plt.xlabel("Iteration")
    plt.ylabel("Reward")
    plt.legend()
    out_path = f"reward_convergence_{mode}.png"
    plt.savefig(out_path)
    plt.close()


def calculate_theta_differences(all_results, projector):
    """
    Calculate relative errors between estimated and true theta after projection.

    For each run, ``theta_gt``, the per-iteration ``theta_iterate`` samples and
    the ``posterior_mean`` estimates are right-multiplied by ``projector``. The
    error of every iterate is then measured in the L2 and L∞ norms and divided
    by the matching norm of the projected ``theta_gt``.

    Args:
        all_results (List[Dict]): List of result dictionaries from multiple runs.
        projector (Tensor | np.ndarray): Projection matrix. Objects with a
            ``detach`` method (torch tensors) are detached and moved to CPU
            before ``.numpy()`` is called. Objects with a ``.numpy()`` method
            are converted with it. Anything else goes through ``np.asarray``.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Relative errors
        with one row per run, in the order (theta_t L2, mu_t L2, theta_t L∞,
        mu_t L∞).
    """
    # Plain ``projector.numpy()`` fails for NumPy arrays (no such method) and
    # for torch tensors that require grad or live on a GPU.
    if hasattr(projector, "detach"):
        projector = projector.detach().cpu()
    if hasattr(projector, "numpy"):
        projector_np = projector.numpy()
    else:
        projector_np = np.asarray(projector)

    theta_diffs_stand_all = []
    theta_mean_diffs_stand_all = []
    theta_diffs_inf_all = []
    theta_mean_diffs_inf_all = []

    for run in all_results:
        theta_iterates = np.array(run["theta_iterate"]) @ projector_np
        theta_gt = np.array(run["theta_gt"]) @ projector_np
        posterior_means = np.array(run["posterior_mean"]) @ projector_np

        # Error vectors are shared by the L2 and L-infinity computations.
        iterate_err = theta_iterates - theta_gt
        mean_err = posterior_means - theta_gt

        # Normalize by *vector* norms of the flattened ground truth. For a 1-D
        # theta_gt this equals the previous np.linalg.norm(theta_gt[, ord=inf]).
        # For a single-row 2-D theta_gt, the matrix inf-norm would instead be
        # the row's L1 sum.
        gt_flat = theta_gt.ravel()
        gt_l2 = np.linalg.norm(gt_flat)
        gt_inf = np.linalg.norm(gt_flat, ord=np.inf)

        # L2 relative errors
        theta_diffs_stand_all.append(np.linalg.norm(iterate_err, axis=1) / gt_l2)
        theta_mean_diffs_stand_all.append(np.linalg.norm(mean_err, axis=1) / gt_l2)

        # L-infinity relative errors
        theta_diffs_inf_all.append(
            np.linalg.norm(iterate_err, ord=np.inf, axis=1) / gt_inf
        )
        theta_mean_diffs_inf_all.append(
            np.linalg.norm(mean_err, ord=np.inf, axis=1) / gt_inf
        )

    return (
        np.array(theta_diffs_stand_all),
        np.array(theta_mean_diffs_stand_all),
        np.array(theta_diffs_inf_all),
        np.array(theta_mean_diffs_inf_all),
    )


def plot_theta_convergence_L2(
    theta_diffs_stand_all, theta_mean_diffs_stand_all, mode=""
):
    """
    Plot the relative L2 error of the theta_t and mu_t estimates over iterations.

    Each estimate is drawn as its across-run mean with a ±1 std band: theta_t
    as a solid blue line, mu_t as a dashed orange line. The figure is saved
    to ``L2_theta_convergence_{mode}.png`` and closed.

    Args:
        theta_diffs_stand_all (np.ndarray): Relative L2 errors of theta_t, one row per run.
        theta_mean_diffs_stand_all (np.ndarray): Relative L2 errors of mu_t, one row per run.
        mode (str): Suffix inserted into the output filename.
    """
    series = (
        (theta_diffs_stand_all, r"$\hat{\theta} = \theta_t$", "blue", {}),
        (
            theta_mean_diffs_stand_all,
            r"$\hat{\theta} = \mu_t$",
            "orange",
            {"linestyle": "--"},
        ),
    )
    # Reduce across runs (axis 0) for every series before opening the figure.
    summaries = [
        (np.mean(values, axis=0), np.std(values, axis=0), label, color, extra)
        for values, label, color, extra in series
    ]
    # Both curves share the x-axis derived from the theta_t series.
    steps = np.arange(len(summaries[0][0]))

    plt.figure(figsize=(10, 6))
    for center, spread, label, color, extra in summaries:
        plt.plot(steps, center, label=label, color=color, **extra)
        plt.fill_between(
            steps, center - spread, center + spread, color=color, alpha=0.2
        )

    plt.xlabel("Iteration")
    plt.ylabel(
        r"$\frac{\|\Pi_V \theta_* - \Pi_V \hat{\theta}\|_2}{\|\Pi_V \theta_*\|_2}$",
        fontsize=20,
    )
    plt.legend()
    plt.grid(True)
    out_path = f"L2_theta_convergence_{mode}.png"
    plt.savefig(out_path)
    plt.close()


def plot_theta_convergence_Linf(theta_diffs_inf_all, theta_mean_diffs_inf_all, mode=""):
    """
    Plot the relative L∞ error of the theta_t and mu_t estimates over iterations.

    Each estimate is drawn as its across-run mean with a ±1 std band: theta_t
    as a solid green line, mu_t as a dashed red line. The figure is saved to
    ``Linf_theta_convergence_{mode}.png`` and closed.

    Args:
        theta_diffs_inf_all (np.ndarray): Relative L∞ errors of theta_t, one row per run.
        theta_mean_diffs_inf_all (np.ndarray): Relative L∞ errors of mu_t, one row per run.
        mode (str): Suffix inserted into the output filename.
    """
    series = (
        (theta_diffs_inf_all, r"$\hat{\theta} = \theta_t$", "green", {}),
        (
            theta_mean_diffs_inf_all,
            r"$\hat{\theta} = \mu_t$",
            "red",
            {"linestyle": "--"},
        ),
    )
    # Reduce across runs (axis 0) for every series before opening the figure.
    summaries = [
        (np.mean(values, axis=0), np.std(values, axis=0), label, color, extra)
        for values, label, color, extra in series
    ]
    # Both curves share the x-axis derived from the theta_t series.
    steps = np.arange(len(summaries[0][0]))

    plt.figure(figsize=(10, 6))
    for center, spread, label, color, extra in summaries:
        plt.plot(steps, center, label=label, color=color, **extra)
        plt.fill_between(
            steps, center - spread, center + spread, color=color, alpha=0.2
        )

    plt.xlabel("Iteration")
    plt.ylabel(
        r"$\frac{\|\Pi_V \theta_* - \Pi_V \hat{\theta}\|_\infty}{\|\Pi_V \theta_*\|_\infty}$",
        fontsize=20,
    )
    plt.legend()
    plt.grid(True)
    out_path = f"Linf_theta_convergence_{mode}.png"
    plt.savefig(out_path)
    plt.close()
